from math import sqrt  # used by gaussian_radius below

import numpy as np
import torch
August 24, 2022
This is a small demo of how the ground truths and losses look in CenterNet. Most of the code is from MMDetection; the idea is to run small pieces of it step by step so that the code is easier to understand. My repo for understanding the architecture can be found here. To understand the architecture itself, please go through the blog by Shreejal Trivedi.
def gaussian2D(radius, sigma=1, dtype=torch.float32, device='cpu'):
    """Generate 2D gaussian kernel.

    Args:
        radius (int): Radius of gaussian kernel.
        sigma (int): Sigma of gaussian function. Default: 1.
        dtype (torch.dtype): Dtype of gaussian tensor. Default: torch.float32.
        device (str): Device of gaussian tensor. Default: 'cpu'.

    Returns:
        h (Tensor): Gaussian kernel with a
            ``(2 * radius + 1) * (2 * radius + 1)`` shape.
    """
    x = torch.arange(
        -radius, radius + 1, dtype=dtype, device=device).view(1, -1)
    y = torch.arange(
        -radius, radius + 1, dtype=dtype, device=device).view(-1, 1)

    h = (-(x * x + y * y) / (2 * sigma * sigma)).exp()

    h[h < torch.finfo(h.dtype).eps * h.max()] = 0
    return h
def gen_gaussian_target(heatmap, center, radius, k=1):
    """Generate 2D gaussian heatmap.

    Args:
        heatmap (Tensor): Input heatmap, the gaussian kernel will cover on
            it and maintain the max value.
        center (list[int]): Coord of gaussian kernel's center.
        radius (int): Radius of gaussian kernel.
        k (int): Coefficient of gaussian kernel. Default: 1.

    Returns:
        out_heatmap (Tensor): Updated heatmap covered by gaussian kernel.
    """
    diameter = 2 * radius + 1
    gaussian_kernel = gaussian2D(
        radius, sigma=diameter / 6, dtype=heatmap.dtype, device=heatmap.device)

    x, y = center

    height, width = heatmap.shape[:2]

    left, right = min(x, radius), min(width - x, radius + 1)
    top, bottom = min(y, radius), min(height - y, radius + 1)

    masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
    masked_gaussian = gaussian_kernel[radius - top:radius + bottom,
                                      radius - left:radius + right]
    out_heatmap = heatmap
    torch.max(
        masked_heatmap,
        masked_gaussian * k,
        out=out_heatmap[y - top:y + bottom, x - left:x + right])

    return out_heatmap
def gaussian_radius(det_size, min_overlap):
    r"""Generate 2D gaussian radius.

    This function is modified from the `official github repo
    <https://github.com/princeton-vl/CornerNet-Lite/blob/master/core/sample/
    utils.py#L65>`_.

    Given ``min_overlap``, the radius can be computed from a quadratic
    equation using the quadratic formula.

    There are 3 cases for computing the gaussian radius, detailed below:

    - Explanation of figure: ``lt`` and ``br`` indicate the left-top and
      bottom-right corners of the ground truth box. ``x`` indicates the
      generated corner at the limited position when ``radius=r``.

    - Case1: one corner is inside the gt box and the other is outside.

    .. code:: text

        |<   width   >|

        lt-+----------+       -
        |  |          |       ^
        +--x----------+--+
        |  |          |  |
        |  |          |  |  height
        |  | overlap  |  |
        |  |          |  |
        |  |          |  |    v
        +--+---------br--+    -
           |          |  |
           +----------+--x

    To ensure IoU of generated box and gt box is larger than ``min_overlap``:

    .. math::
        \cfrac{(w-r)*(h-r)}{w*h+(w+h)r-r^2} \ge {iou} \quad\Rightarrow\quad
        {r^2-(w+h)r+\cfrac{1-iou}{1+iou}*w*h} \ge 0 \\
        {a} = 1,\quad{b} = {-(w+h)},\quad{c} = {\cfrac{1-iou}{1+iou}*w*h}
        {r} \le \cfrac{-b-\sqrt{b^2-4*a*c}}{2*a}

    - Case2: both corners are inside the gt box.

    .. code:: text

        |<   width   >|

        lt-+----------+    -
        |  |          |    ^
        +--x-------+  |
        |  |       |  |
        |  |overlap|  |  height
        |  |       |  |
        |  +-------x--+
        |          |  |    v
        +----------+-br    -

    To ensure IoU of generated box and gt box is larger than ``min_overlap``:

    .. math::
        \cfrac{(w-2*r)*(h-2*r)}{w*h} \ge {iou} \quad\Rightarrow\quad
        {4r^2-2(w+h)r+(1-iou)*w*h} \ge 0 \\
        {a} = 4,\quad {b} = {-2(w+h)},\quad {c} = {(1-iou)*w*h}
        {r} \le \cfrac{-b-\sqrt{b^2-4*a*c}}{2*a}

    - Case3: both corners are outside the gt box.

    .. code:: text

           |<   width   >|

        x--+----------------+
        |  |                |
        +-lt-------------+  |    -
        |  |             |  |    ^
        |  |             |  |
        |  |   overlap   |  |  height
        |  |             |  |
        |  |             |  |    v
        |  +------------br--+    -
        |                |  |
        +----------------+--x

    To ensure IoU of generated box and gt box is larger than ``min_overlap``:

    .. math::
        \cfrac{w*h}{(w+2*r)*(h+2*r)} \ge {iou} \quad\Rightarrow\quad
        {4*iou*r^2+2*iou*(w+h)r+(iou-1)*w*h} \le 0 \\
        {a} = {4*iou},\quad {b} = {2*iou*(w+h)},\quad {c} = {(iou-1)*w*h} \\
        {r} \le \cfrac{-b+\sqrt{b^2-4*a*c}}{2*a}

    Args:
        det_size (list[int]): Shape of object.
        min_overlap (float): Min IoU with ground truth for boxes generated by
            keypoints inside the gaussian kernel.

    Returns:
        radius (int): Radius of gaussian kernel.
    """
    height, width = det_size

    # Case 1: one corner inside, one outside (b1 stores -b, so r1 = (-b - sqrt) / 2a)
    a1 = 1
    b1 = (height + width)
    c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
    sq1 = sqrt(b1**2 - 4 * a1 * c1)
    r1 = (b1 - sq1) / (2 * a1)

    # Case 2: both corners inside (b2 stores -b)
    a2 = 4
    b2 = 2 * (height + width)
    c2 = (1 - min_overlap) * width * height
    sq2 = sqrt(b2**2 - 4 * a2 * c2)
    r2 = (b2 - sq2) / (2 * a2)

    # Case 3: both corners outside (b3 stores -b, so r3 = (-b + sqrt) / 2a)
    a3 = 4 * min_overlap
    b3 = -2 * min_overlap * (height + width)
    c3 = (min_overlap - 1) * width * height
    sq3 = sqrt(b3**2 - 4 * a3 * c3)
    r3 = (b3 + sq3) / (2 * a3)

    # The smallest radius satisfies all three cases
    return min(r1, r2, r3)
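As a sanity check on the three cases in the docstring, the sketch below solves each quadratic and plugs the root back into the matching IoU expression from the figures. Each root should give an IoU of exactly `min_overlap`, which is why taking the smallest of the three satisfies every case. The helper names `case_radii` and `case_ious`, and the 10 x 20 box with `min_overlap=0.7`, are hypothetical and chosen only for illustration.

```python
from math import sqrt


def case_radii(h, w, iou):
    """Solve the three quadratics from the gaussian_radius docstring for r."""
    # Case 1: r^2 - (w+h) r + (1-iou)/(1+iou) w h = 0, smaller root
    a, b, c = 1.0, -(w + h), (1 - iou) / (1 + iou) * w * h
    r1 = (-b - sqrt(b * b - 4 * a * c)) / (2 * a)
    # Case 2: 4 r^2 - 2 (w+h) r + (1-iou) w h = 0, smaller root
    a, b, c = 4.0, -2 * (w + h), (1 - iou) * w * h
    r2 = (-b - sqrt(b * b - 4 * a * c)) / (2 * a)
    # Case 3: 4 iou r^2 + 2 iou (w+h) r + (iou-1) w h = 0, positive root
    a, b, c = 4 * iou, 2 * iou * (w + h), (iou - 1) * w * h
    r3 = (-b + sqrt(b * b - 4 * a * c)) / (2 * a)
    return r1, r2, r3


def case_ious(h, w, r1, r2, r3):
    """IoU of the generated box and the gt box in each of the three figures."""
    iou1 = (w - r1) * (h - r1) / (w * h + (w + h) * r1 - r1 ** 2)  # one corner in, one out
    iou2 = (w - 2 * r2) * (h - 2 * r2) / (w * h)                     # both corners inside
    iou3 = w * h / ((w + 2 * r3) * (h + 2 * r3))                     # both corners outside
    return iou1, iou2, iou3


h, w, min_overlap = 10.0, 20.0, 0.7  # hypothetical box size and IoU threshold
r1, r2, r3 = case_radii(h, w, min_overlap)
print("radii:", [round(r, 4) for r in (r1, r2, r3)])
print("IoUs:", [round(v, 6) for v in case_ious(h, w, r1, r2, r3)])  # each equals min_overlap
print("radius =", min(r1, r2, r3))  # the same choice gaussian_radius makes
```

For this particular box, Case 2 gives the tightest radius.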
tensor([[0.0183, 0.0821, 0.1353, 0.0821, 0.0183],
[0.0821, 0.3679, 0.6065, 0.3679, 0.0821],
[0.1353, 0.6065, 1.0000, 0.6065, 0.1353],
[0.0821, 0.3679, 0.6065, 0.3679, 0.0821],
[0.0183, 0.0821, 0.1353, 0.0821, 0.0183]])
Our radius was 2, so the peak value of 1 sits at index (2, 2) and the values fall off around it like a Gaussian. This kernel is then copied onto the heatmap at the object's center. If the center lies near a border or a corner, the kernel has to be cropped so that it fits.
1. To start, assume that the heatmap has shape (8, 8) and that the object is centered at (3, 3).
2. We copy the Gaussian kernel to that position, as shown in the code below.
radius = 2
h = gaussian2D(radius)  # the 5x5 kernel printed above (sigma=1)
heatmap = torch.zeros((8, 8))
height, width = heatmap.shape[:2]
center = (3, 3)
x, y = center
# Clip the kernel's extent so it stays inside the heatmap
# (needed when the center is near a border or corner)
left, right = min(x, radius), min(width - x, radius + 1)
top, bottom = min(y, radius), min(height - y, radius + 1)
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
masked_gaussian = h[radius - top:radius + bottom,
                    radius - left:radius + right]
tensor([[0.0183, 0.0821, 0.1353, 0.0821, 0.0183],
[0.0821, 0.3679, 0.6065, 0.3679, 0.0821],
[0.1353, 0.6065, 1.0000, 0.6065, 0.1353],
[0.0821, 0.3679, 0.6065, 0.3679, 0.0821],
[0.0183, 0.0821, 0.1353, 0.0821, 0.0183]])
tensor([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]])
# out_heatmap is the same tensor as heatmap, so the element-wise max is written into it in place
out_heatmap = heatmap
torch.max(
    masked_heatmap,
    masked_gaussian,
    out=out_heatmap[y - top:y + bottom, x - left:x + right])
tensor([[0.0183, 0.0821, 0.1353, 0.0821, 0.0183],
[0.0821, 0.3679, 0.6065, 0.3679, 0.0821],
[0.1353, 0.6065, 1.0000, 0.6065, 0.1353],
[0.0821, 0.3679, 0.6065, 0.3679, 0.0821],
[0.0183, 0.0821, 0.1353, 0.0821, 0.0183]])
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0183, 0.0821, 0.1353, 0.0821, 0.0183, 0.0000, 0.0000],
[0.0000, 0.0821, 0.3679, 0.6065, 0.3679, 0.0821, 0.0000, 0.0000],
[0.0000, 0.1353, 0.6065, 1.0000, 0.6065, 0.1353, 0.0000, 0.0000],
[0.0000, 0.0821, 0.3679, 0.6065, 0.3679, 0.0821, 0.0000, 0.0000],
[0.0000, 0.0183, 0.0821, 0.1353, 0.0821, 0.0183, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
We can see that the kernel has been placed with its peak value of 1 at (3, 3), which is the center we gave.
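The center above lies well inside the heatmap, so nothing gets clipped. The sketch below reuses the same left/right/top/bottom arithmetic for a hypothetical center at the top-left corner (0, 0), which is the case the cropping exists for. To keep the block self-contained, it re-declares a trimmed copy of `gaussian2D`. It also uses `torch.maximum` instead of `torch.max(..., out=...)`. Both choices are ours, not the original code's.

```python
import torch


def gaussian2D(radius, sigma=1.0):
    # Trimmed copy of gaussian2D above (float32, CPU)
    x = torch.arange(-radius, radius + 1, dtype=torch.float32).view(1, -1)
    y = torch.arange(-radius, radius + 1, dtype=torch.float32).view(-1, 1)
    h = (-(x * x + y * y) / (2 * sigma * sigma)).exp()
    h[h < torch.finfo(h.dtype).eps * h.max()] = 0
    return h


radius = 2
kernel = gaussian2D(radius)
heatmap = torch.zeros((8, 8))
height, width = heatmap.shape
x, y = 0, 0  # hypothetical center on the top-left corner of the heatmap

left, right = min(x, radius), min(width - x, radius + 1)
top, bottom = min(y, radius), min(height - y, radius + 1)
print(left, right, top, bottom)  # 0 3 0 3: nothing fits left of or above the center

# Only the bottom-right 3x3 quarter of the 5x5 kernel lands on the heatmap
crop = kernel[radius - top:radius + bottom, radius - left:radius + right]
region = heatmap[y - top:y + bottom, x - left:x + right]
region.copy_(torch.maximum(region, crop))  # region is a view, so heatmap is updated in place
print(heatmap[:4, :4])
```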
Now let's look at how the losses work.
# Ground truth heatmaps with shape (batch, num_classes, height, width): 2 images, 4 classes, 8x8
groundTruth = torch.zeros((2, 4, 8, 8), dtype=torch.float32)
print("GroundTruth shape", groundTruth.shape)
# Now copy the heatmap generated above into the channel of each object's class id.
# Here we assume that in the first image an object of class 0 has the bounding box,
# and in the second image an object of class 2 does. For simplicity, both images
# share the same heatmap and center, so the assignment is:
groundTruth[0, 0, :, :] = out_heatmap.clone()
groundTruth[1, 2, :, :] = out_heatmap.clone()
print("Ground Truth \n", groundTruth)
GroundTruth shape torch.Size([2, 4, 8, 8])
Ground Truth
tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0183, 0.0821, 0.1353, 0.0821, 0.0183, 0.0000, 0.0000],
[0.0000, 0.0821, 0.3679, 0.6065, 0.3679, 0.0821, 0.0000, 0.0000],
[0.0000, 0.1353, 0.6065, 1.0000, 0.6065, 0.1353, 0.0000, 0.0000],
[0.0000, 0.0821, 0.3679, 0.6065, 0.3679, 0.0821, 0.0000, 0.0000],
[0.0000, 0.0183, 0.0821, 0.1353, 0.0821, 0.0183, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],
[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0183, 0.0821, 0.1353, 0.0821, 0.0183, 0.0000, 0.0000],
[0.0000, 0.0821, 0.3679, 0.6065, 0.3679, 0.0821, 0.0000, 0.0000],
[0.0000, 0.1353, 0.6065, 1.0000, 0.6065, 0.1353, 0.0000, 0.0000],
[0.0000, 0.0821, 0.3679, 0.6065, 0.3679, 0.0821, 0.0000, 0.0000],
[0.0000, 0.0183, 0.0821, 0.1353, 0.0821, 0.0183, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
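In the demo, both images share one heatmap and the class channel is picked by hand. With several objects per image, one common pattern is to loop over the objects and take the element-wise maximum into the channel of each object's class. That way, overlapping objects of the same class keep the higher value. This is a sketch, not MMDetection's exact code. The `objects` list, the helper `draw_gaussian`, and the fixed radius of 2 are hypothetical. The kernel uses sigma=1 to match the demo, whereas `gen_gaussian_target` uses `diameter / 6`.

```python
import torch


def gaussian2D(radius, sigma=1.0):
    # Same kernel as above (sigma=1 to match the demo; gen_gaussian_target uses diameter / 6)
    x = torch.arange(-radius, radius + 1, dtype=torch.float32).view(1, -1)
    y = torch.arange(-radius, radius + 1, dtype=torch.float32).view(-1, 1)
    return (-(x * x + y * y) / (2 * sigma * sigma)).exp()


def draw_gaussian(heatmap, center, radius):
    """Write a Gaussian peak into a 2-D heatmap view, keeping the element-wise max."""
    kernel = gaussian2D(radius)
    x, y = center
    height, width = heatmap.shape
    left, right = min(x, radius), min(width - x, radius + 1)
    top, bottom = min(y, radius), min(height - y, radius + 1)
    region = heatmap[y - top:y + bottom, x - left:x + right]
    crop = kernel[radius - top:radius + bottom, radius - left:radius + right]
    region.copy_(torch.maximum(region, crop))


# Hypothetical objects: per image, a list of (class_id, (x, y) center on the 8x8 grid)
objects = [
    [(0, (3, 3))],               # image 0: one object of class 0, as in the demo
    [(2, (3, 3)), (2, (6, 1))],  # image 1: two objects of class 2
]
num_classes, H, W = 4, 8, 8
target = torch.zeros((len(objects), num_classes, H, W))
for b, objs in enumerate(objects):
    for class_id, center in objs:
        draw_gaussian(target[b, class_id], center, radius=2)  # fixed radius, for illustration

# Every object leaves exactly one cell equal to 1: this is the number of object centers N
print(int(target.eq(1).sum()))  # 3
```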
Now we will make a random prediction and see how the losses are calculated.
From the heatmap we can see that negative samples (positions that are not an object center) vastly outnumber positive ones. This is why CenterNet uses a modified version of focal loss to counteract the imbalance.
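The CenterNet paper writes this penalty-reduced focal loss as follows, where N is the number of object centers and α = 2, β = 4:

L_k = -(1/N) Σ [ (1 - Ŷ)^α log(Ŷ) where Y = 1, and (1 - Y)^β Ŷ^α log(1 - Ŷ) elsewhere ]

Below is a minimal PyTorch sketch of that formula. It assumes `pred` already holds sigmoid probabilities with the same shape as the target. The function name and the `eps` guard are our own choices. For brevity, the target here only has the two peaks rather than the full Gaussians built above.

```python
import torch


def penalty_reduced_focal_loss(pred, target, alpha=2.0, beta=4.0, eps=1e-12):
    """Focal loss on Gaussian heatmap targets, following the CenterNet paper's formula.

    pred, target: tensors of the same shape, e.g. (B, C, H, W); pred holds
    probabilities in (0, 1) (typically a sigmoid of the network output).
    """
    pos = target.eq(1).float()           # object centers, where Y = 1
    neg_weight = (1 - target).pow(beta)  # (1 - Y)^beta: small near a center, 0 at the center
    pos_loss = -(pred + eps).log() * (1 - pred).pow(alpha) * pos
    neg_loss = -(1 - pred + eps).log() * pred.pow(alpha) * neg_weight
    num_pos = pos.sum().clamp(min=1)     # N, guarded so an empty image does not divide by 0
    return (pos_loss.sum() + neg_loss.sum()) / num_pos


torch.manual_seed(0)
target = torch.zeros((2, 4, 8, 8))
target[0, 0, 3, 3] = 1.0                   # the two centers from the demo (peaks only)
target[1, 2, 3, 3] = 1.0
pred = torch.rand((2, 4, 8, 8)).sigmoid()  # a random "prediction" squashed into (0, 1)
print(penalty_reduced_focal_loss(pred, target))
print(penalty_reduced_focal_loss(target.clone(), target))  # a perfect prediction gives 0
```

The `(1 - Y)^β` factor is what makes the Gaussian useful. Pixels close to a center have Y near 1, so predicting a high score there is penalized much less than at far-away background pixels.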
# Take (28, 28) as the object's center in input-image pixels and 8 as the
# downsampling ratio, so the center on the 8x8 output grid is:
ctx, cty = 28 / 8, 28 / 8
ctx_int, cty_int = int(ctx), int(cty)
print("original scaled down", (ctx, cty))
print("floored version", (ctx_int, cty_int))
print("offset is", (ctx - ctx_int, cty - cty_int))
original scaled down (3.5, 3.5)
floored version (3, 3)
offset is (0.5, 0.5)
# Offset targets with shape (batch, 2, height, width): channel 0 is the x offset, channel 1 the y offset
groundTruthWHOffset = torch.zeros((2, 2, 8, 8), dtype=torch.float32)
# We assumed the object is at the same position in both images, so the target is written
# at the floored center (the width/height target is set at the same position).
# Tensors are indexed [batch, channel, row (y), column (x)].
groundTruthWHOffset[0, 0, cty_int, ctx_int] = ctx - ctx_int
groundTruthWHOffset[0, 1, cty_int, ctx_int] = cty - cty_int
# The object is at the same position in the second image (batch id 1)
groundTruthWHOffset[1, 0, cty_int, ctx_int] = ctx - ctx_int
groundTruthWHOffset[1, 1, cty_int, ctx_int] = cty - cty_int
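The snippet above only fills the offset target, and the 28/8 center is given directly. Here is a sketch of deriving both the size and offset targets from a single bounding box. It assumes a hypothetical box (x1, y1, x2, y2) = (12, 20, 44, 36) in input pixels and a downsampling ratio of 8. The box's center is the same (28, 28) as above, but its width and height differ so the two channels can be told apart. Expressing the size target in output-grid cells (box size divided by the stride) is an assumption of this sketch.

```python
import torch

stride = 8                        # assumed downsampling ratio (input pixels per grid cell)
box = (12.0, 20.0, 44.0, 36.0)    # hypothetical (x1, y1, x2, y2) in input pixels
batch, H, W = 2, 8, 8

wh_target = torch.zeros((batch, 2, H, W))
offset_target = torch.zeros((batch, 2, H, W))

x1, y1, x2, y2 = box
ctx, cty = (x1 + x2) / 2 / stride, (y1 + y2) / 2 / stride  # (3.5, 3.5), the 28/8 from above
ctx_int, cty_int = int(ctx), int(cty)                      # the cell that holds the heatmap peak
for b in range(batch):  # same object in both images, as in the demo
    # Index order is [batch, channel, row (y), column (x)]
    wh_target[b, 0, cty_int, ctx_int] = (x2 - x1) / stride  # width in grid cells
    wh_target[b, 1, cty_int, ctx_int] = (y2 - y1) / stride  # height in grid cells
    offset_target[b, 0, cty_int, ctx_int] = ctx - ctx_int   # sub-cell x offset
    offset_target[b, 1, cty_int, ctx_int] = cty - cty_int   # sub-cell y offset

print(wh_target[0, :, 3, 3])      # width 4, height 2
print(offset_target[0, :, 3, 3])  # 0.5, 0.5
```

Because ctx equals cty in the demo, swapping the row and column indices would not change any numbers there. With an off-diagonal center, however, the wrong order would write the target into the wrong cell.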
We also need weights, because the size and offset losses should only be counted at the positions where there was an object.
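Below is a minimal sketch of such a weight mask, playing the role of `groundTruthWHOffsetWeights`. It assumes the same (batch, 2, H, W) layout as the offset target and the center cell (3, 3) in both images. The mask is 1 at each object's center cell in both channels and 0 everywhere else. Multiplying an element-wise loss by it therefore zeroes every other position. The random `elementwise_loss` is only a stand-in.

```python
import torch

batch, H, W = 2, 8, 8
centers = [(3, 3), (3, 3)]       # (ctx_int, cty_int) for the object in each image, as in the demo
weights = torch.zeros((batch, 2, H, W))
for b, (cx, cy) in enumerate(centers):
    weights[b, :, cy, cx] = 1.0  # both channels (w/h, or x/y offset) at the object's center cell

torch.manual_seed(0)
elementwise_loss = torch.rand((batch, 2, H, W))  # stand-in for |pred - target| at every cell
masked = elementwise_loss * weights

print(int(weights.sum()))  # 4: 2 images x 2 channels
# Only the center cells survive the mask
print(torch.allclose(masked.sum(), elementwise_loss[:, :, 3, 3].sum()))  # True
```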
Let's make a random prediction to calculate the loss.
The same loss is used for both wh and wh_offset: the L1 loss.
def l1_loss(pred, target):
    """L1 loss.

    Args:
        pred (torch.Tensor): The prediction.
        target (torch.Tensor): The learning target of the prediction.

    Returns:
        torch.Tensor: Calculated loss
    """
    if target.numel() == 0:
        return pred.sum() * 0

    assert pred.size() == target.size()
    loss = torch.abs(pred - target)
    return loss
# Note: we multiply by the weights at the end so that only the required positions contribute
WHLoss = l1_loss(predWH, groundTruthWH) * groundTruthWHOffsetWeights
WHOffsetLoss = l1_loss(predWHOffset, groundTruthWHOffset) * groundTruthWHOffsetWeights
WHLoss = WHLoss.sum()
WHOffsetLoss = WHOffsetLoss.sum()
Ground Truth Loss tensor(3.6995)
Ground Offset Loss tensor(2.0938)
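`WHLoss` and `WHOffsetLoss` above are plain sums over the masked positions, so they grow with the number of objects in the batch. The CenterNet paper instead averages the size and offset L1 terms over the N object centers. Below is a sketch of that normalization. The random predictions and targets, the one-object-per-image mask, and the helper name `masked_l1` are all our own assumptions.

```python
import torch


def masked_l1(pred, target, weight):
    """Sum of |pred - target| over weighted cells, divided by the number of objects N."""
    num_objects = weight[:, 0].sum().clamp(min=1)  # channel 0 has one active cell per object
    return (torch.abs(pred - target) * weight).sum() / num_objects


torch.manual_seed(0)
batch, H, W = 2, 8, 8
weight = torch.zeros((batch, 2, H, W))
weight[:, :, 3, 3] = 1.0               # one object per image at cell (3, 3)
pred = torch.rand((batch, 2, H, W))    # random "prediction"
target = torch.rand((batch, 2, H, W))  # random stand-in target

loss = masked_l1(pred, target, weight)
# Doubling the batch doubles both the summed error and N, so the loss stays the same
loss_x2 = masked_l1(pred.repeat(2, 1, 1, 1), target.repeat(2, 1, 1, 1), weight.repeat(2, 1, 1, 1))
print(loss, loss_x2)
```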
The final loss is the weighted sum of the heatmap loss, the wh loss and the wh_offset loss. A few more things are involved in a real implementation, and in the repo I have shown how they are actually done. I hope this was helpful!
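The CenterNet paper combines the three terms as L_det = L_k + λ_size · L_size + λ_off · L_off, with λ_size = 0.1 and λ_off = 1. Here is a one-function sketch of that combination. The function name is ours, and the three input values are placeholders rather than results from the demo. In practice the inputs would be the heatmap focal loss and the two normalized L1 losses.

```python
import torch


def centernet_total_loss(heatmap_loss, wh_loss, offset_loss, lambda_size=0.1, lambda_off=1.0):
    # L_det = L_k + lambda_size * L_size + lambda_off * L_off
    # (lambda_size = 0.1 and lambda_off = 1 are the values reported in the CenterNet paper)
    return heatmap_loss + lambda_size * wh_loss + lambda_off * offset_loss


# Placeholder loss values, for illustration only
total = centernet_total_loss(torch.tensor(2.0), torch.tensor(4.0), torch.tensor(0.5))
print(total)
```

The small size weight keeps the raw width/height errors, which are measured in grid cells and can be large, from dominating the heatmap term.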